from flask import send_file
from flask_restplus import Namespace, Resource, reqparse
from flask_login import login_required, current_user

import datetime
import logging
from ..util import query_util
import tempfile
import os
import shutil
import xml.etree.ElementTree as ET
import zipfile
import json
from database import (
    ExportModel,
    DatasetModel,
    fix_ids
)

# Presumably gunicorn's error logger so messages land in the gunicorn log.
logger = logging.getLogger('gunicorn.error')

api = Namespace('export', description='Export related operations')

# Request options parsed by the export download endpoint.
annotation_format = reqparse.RequestParser()
annotation_format.add_argument(
    'anno_format',
    required=False,
    type=str,
    help="Annotation format to download: 'yolo' or 'voc' (zipped archive); "
         "omit for the original JSON export",
)

@api.route('/<int:export_id>')
class DatasetExports(Resource):
    """Read or delete a single export of one of the current user's datasets."""

    @staticmethod
    def _find_export(export_id):
        """Look up an export that belongs to one of the current user's datasets.

        :param export_id: id of the export to look up
        :return: ``(export, None)`` on success, or ``(None, (body, status))``
            where the second item is the error response to return as-is.
        """
        export = ExportModel.objects(id=export_id).first()
        if export is None:
            return None, ({"message": "Invalid export ID"}, 400)

        # Only exports of datasets visible to the current user are accessible.
        dataset = current_user.datasets.filter(id=export.dataset_id).first()
        if dataset is None:
            return None, ({"message": "Invalid dataset ID"}, 400)

        return export, None

    @login_required
    def get(self, export_id):
        """Return the export (serialized via ``fix_ids``) plus ``ago``, its age as text.

        Responds 400 if the export does not exist or its dataset is not among
        the current user's datasets.
        """
        export, error = self._find_export(export_id)
        if error is not None:
            return error

        time_delta = datetime.datetime.utcnow() - export.created_at
        d = fix_ids(export)
        d['ago'] = query_util.td_format(time_delta)
        return d

    @login_required
    def delete(self, export_id):
        """Delete an export record and return ``{'success': True}``.

        Responds 400 if the export does not exist or its dataset is not among
        the current user's datasets.
        """
        export, error = self._find_export(export_id)
        if error is not None:
            return error

        # NOTE(review): whether this also removes the file at ``export.path``
        # depends on ExportModel.delete -- confirm.
        export.delete()
        return {'success': True}


@api.route('/<int:export_id>/download')
class DatasetExports(Resource):
    """Download an export, optionally converted to YOLO or Pascal VOC."""
    # NOTE(review): this class reuses the name ``DatasetExports`` of the resource
    # routed at '/<int:export_id>' in this file, so the module-level name ends up
    # bound to this class. Both routes are presumably still registered because
    # ``@api.route`` is applied when each class is created; the name is kept so
    # any endpoint names derived from it (and imports elsewhere) do not change.

    @api.expect(annotation_format)
    @login_required
    def get(self, export_id):
        """Send an export as a file attachment.

        Query option ``anno_format``:

        * ``'yolo'`` -- zip archive in YOLO layout (see ``convert_to_yolo_format``);
        * ``'voc'`` -- zip archive in Pascal VOC layout (see ``convert_to_voc_format``);
        * anything else or omitted -- the stored JSON export, unchanged.

        Responds 400 if the export does not exist or its dataset is not among
        the current user's datasets, 403 if the user may not download the
        dataset, and 500 if a format conversion fails.
        """
        args = annotation_format.parse_args()
        anno_format = args['anno_format']

        export = ExportModel.objects(id=export_id).first()
        if export is None:
            return {"message": "Invalid export ID"}, 400

        dataset = current_user.datasets.filter(id=export.dataset_id).first()
        if dataset is None:
            return {"message": "Invalid dataset ID"}, 400

        if not current_user.can_download(dataset):
            return {"message": "You do not have permission to download the dataset's annotations"}, 403

        base_name = f"{dataset.name}-{'-'.join(export.tags)}"
        converters = {
            'yolo': self.convert_to_yolo_format,
            'voc': self.convert_to_voc_format,
        }
        converter = converters.get(anno_format)

        # NOTE(review): Flask 2.0 renamed ``attachment_filename`` to
        # ``download_name``; update if Flask is upgraded.
        if converter is None:
            return send_file(export.path, attachment_filename=f"{base_name}.json",
                             as_attachment=True)

        converted_file_path = converter(export.path)
        if not converted_file_path:
            return {"message": "Failed to convert file format"}, 500
        return send_file(converted_file_path, attachment_filename=f"{base_name}.zip",
                         as_attachment=True)

    @staticmethod
    def _group_annotations_by_image(coco_data):
        """Map ``image_id`` -> list of its annotations, preserving input order.

        Built once per conversion so each image does not rescan every annotation.
        A missing ``annotations`` key is treated as no annotations.
        """
        grouped = {}
        for annotation in coco_data.get('annotations', []):
            grouped.setdefault(annotation['image_id'], []).append(annotation)
        return grouped

    @staticmethod
    def _bbox_xywh(annotation):
        """Return the annotation's ``(x, y, w, h)`` bbox, or ``None`` if it has none.

        Annotations may lack a usable ``bbox``; those cannot be exported as boxes.
        """
        bbox = annotation.get('bbox')
        if not bbox or len(bbox) < 4:
            return None
        return bbox[0], bbox[1], bbox[2], bbox[3]

    @staticmethod
    def _zip_directory(source_dir, archive_name):
        """Zip every file under ``source_dir`` and return the archive path.

        Entries are stored relative to ``source_dir``. The archive is written
        into a fresh temporary directory so ``source_dir`` can be deleted
        afterwards.
        """
        # NOTE(review): the directory holding the zip is not removed here because
        # the caller still has to stream it with send_file -- TODO clean it up
        # once the response has been sent.
        zip_file_path = os.path.join(tempfile.mkdtemp(), archive_name)
        with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
            for root, _, files in os.walk(source_dir):
                for name in files:
                    file_path = os.path.join(root, name)
                    zip_file.write(file_path, os.path.relpath(file_path, source_dir))
        return zip_file_path

    def convert_to_yolo_format(self, original_file_path):
        """Convert a COCO-style JSON export into a zipped YOLO dataset.

        The archive contains:

        * ``labels/<image stem>.txt`` -- one line per annotation with a bbox:
          ``<class index> <x_center> <y_center> <width> <height>``, with the box
          values divided by the image width/height. Every image gets a file,
          even if it has no boxes.
        * ``images/<file name>`` -- a copy of each image, read from ``image['path']``.
        * ``classes.txt`` -- one category name per line; line ``i`` (0-based) is
          class index ``i`` in the label files.
        * ``categories.json`` -- the original category ``id``/``name`` pairs.

        Annotations without a 4-value ``bbox`` are skipped. Annotations whose
        ``category_id`` is not among the categories are skipped with a warning.

        :param original_file_path: path of the COCO-style JSON export
        :return: path of the generated zip file, or ``None`` if conversion failed
        """
        work_dir = None
        try:
            with open(original_file_path, 'r') as json_file:
                coco_data = json.load(json_file)

            work_dir = tempfile.mkdtemp(prefix='yolo_conversion_')
            label_dir = os.path.join(work_dir, 'labels')
            image_dir = os.path.join(work_dir, 'images')
            os.makedirs(label_dir)
            os.makedirs(image_dir)

            # Map id -> name in first-seen order; this order defines classes.txt.
            category_names = {}
            for category in coco_data['categories']:
                category_names[category['id']] = category['name']
            # YOLO class indices are 0-based line numbers in classes.txt, while
            # COCO category ids are arbitrary (often 1-based, possibly sparse).
            class_index = {cat_id: idx for idx, cat_id in enumerate(category_names)}

            # Write UTF-8 explicitly: names may be non-ASCII, and the default
            # encoding depends on the server locale.
            classes_path = os.path.join(work_dir, 'classes.txt')
            with open(classes_path, 'w', encoding='utf-8') as classes_file:
                classes_file.writelines(f"{name}\n" for name in category_names.values())

            categories = [{'id': c['id'], 'name': c['name']} for c in coco_data['categories']]
            with open(os.path.join(work_dir, 'categories.json'), 'w') as categories_file:
                json.dump({'categories': categories}, categories_file, indent=2)

            annotations_by_image = self._group_annotations_by_image(coco_data)

            for image in coco_data['images']:
                width = image['width']
                height = image['height']
                file_name = os.path.basename(image['file_name'])
                stem = os.path.splitext(file_name)[0]

                yolo_lines = []
                for annotation in annotations_by_image.get(image['id'], []):
                    bbox = self._bbox_xywh(annotation)
                    if bbox is None:
                        continue
                    category_id = annotation['category_id']
                    if category_id not in class_index:
                        logger.warning("Skipping annotation %s: unknown category %s",
                                       annotation.get('id'), category_id)
                        continue
                    x, y, w, h = bbox
                    # COCO boxes are top-left x/y + size; YOLO wants the centre,
                    # normalized by the image size.
                    x_center = (x + w / 2) / width
                    y_center = (y + h / 2) / height
                    yolo_lines.append(
                        f"{class_index[category_id]} {x_center} {y_center} "
                        f"{w / width} {h / height}\n")

                with open(os.path.join(label_dir, f"{stem}.txt"), 'w') as txt_file:
                    txt_file.writelines(yolo_lines)

                shutil.copyfile(image['path'], os.path.join(image_dir, file_name))

            return self._zip_directory(work_dir, 'yolo_annotations.zip')

        except Exception as e:
            logger.exception("Failed to convert to YOLO format: %s", e)
            return None
        finally:
            # The zip lives in its own temp dir, so the working copy can go,
            # including after a failed conversion.
            if work_dir is not None:
                shutil.rmtree(work_dir, ignore_errors=True)

    def convert_to_voc_format(self, original_file_path):
        """Convert a COCO-style JSON export into a zipped Pascal VOC dataset.

        The archive contains:

        * ``images/<file name>`` -- a copy of each image, read from ``image['path']``
          and keeping its original file name and extension.
        * ``Annotations/<image stem>.xml`` -- one VOC ``<annotation>`` document per
          image. It has one ``<object>`` (category name plus ``bndbox`` with
          integer-truncated corners) per annotation that has a bbox.

        Annotations without a 4-value ``bbox`` are skipped. Annotations whose
        ``category_id`` is not among the categories are skipped with a warning.

        :param original_file_path: path of the COCO-style JSON export
        :return: path of the generated zip file, or ``None`` if conversion failed
        """
        work_dir = None
        try:
            with open(original_file_path, 'r') as json_file:
                coco_data = json.load(json_file)

            work_dir = tempfile.mkdtemp(prefix='voc_conversion_')
            image_dir = os.path.join(work_dir, 'images')
            annotation_dir = os.path.join(work_dir, 'Annotations')
            os.makedirs(image_dir)
            os.makedirs(annotation_dir)

            # Look names up by category id; ids are not positions in the list.
            category_names = {c['id']: c['name'] for c in coco_data['categories']}
            annotations_by_image = self._group_annotations_by_image(coco_data)

            for image in coco_data['images']:
                file_name = os.path.basename(image['file_name'])
                stem = os.path.splitext(file_name)[0]

                # Keep the real extension so the file's bytes match its name,
                # instead of forcing every image to ``.jpg``.
                shutil.copyfile(image['path'], os.path.join(image_dir, file_name))

                xml_root = ET.Element('annotation')
                # NOTE(review): VOC conventionally names this folder 'JPEGImages',
                # but this archive stores images under 'images/'. Kept unchanged.
                ET.SubElement(xml_root, 'folder').text = 'JPEGImages'
                ET.SubElement(xml_root, 'filename').text = file_name
                size = ET.SubElement(xml_root, 'size')
                ET.SubElement(size, 'width').text = str(image['width'])
                ET.SubElement(size, 'height').text = str(image['height'])
                # Hard-coded rather than read from the image (3 channels assumed).
                ET.SubElement(size, 'depth').text = '3'

                for annotation in annotations_by_image.get(image['id'], []):
                    bbox = self._bbox_xywh(annotation)
                    if bbox is None:
                        continue
                    category_id = annotation['category_id']
                    if category_id not in category_names:
                        logger.warning("Skipping annotation %s: unknown category %s",
                                       annotation.get('id'), category_id)
                        continue
                    x, y, w, h = bbox

                    obj = ET.SubElement(xml_root, 'object')
                    ET.SubElement(obj, 'name').text = category_names[category_id]
                    bndbox = ET.SubElement(obj, 'bndbox')
                    ET.SubElement(bndbox, 'xmin').text = str(int(x))
                    ET.SubElement(bndbox, 'ymin').text = str(int(y))
                    ET.SubElement(bndbox, 'xmax').text = str(int(x + w))
                    ET.SubElement(bndbox, 'ymax').text = str(int(y + h))

                ET.ElementTree(xml_root).write(os.path.join(annotation_dir, f"{stem}.xml"))

            return self._zip_directory(work_dir, 'voc_annotations.zip')

        except Exception as e:
            logger.exception("Failed to convert to Pascal VOC format: %s", e)
            return None
        finally:
            # The zip lives in its own temp dir, so the working copy can go,
            # including after a failed conversion.
            if work_dir is not None:
                shutil.rmtree(work_dir, ignore_errors=True)